import os
import csv
import torch
import random
import logging
import torchvision
import numpy as np
import pandas as pd
from tqdm import tqdm
from random import shuffle
from PIL import Image, ImageFile
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, Subset
from torch.utils.data.distributed import DistributedSampler
import wandb
import spacy
import sng_parser

from utils.augment_text import _augment_text
from utils.augment_image import _augment_image
from backdoor.utils import apply_trigger

ImageFile.LOAD_TRUNCATED_IMAGES = True





class ImageCaptionDataset(Dataset):
    """Image/caption pairs read from a delimited text file.

    Rows containing any missing value are dropped. Image paths are resolved
    relative to the directory that holds the file. All captions are run
    through ``processor.process_text`` once at construction time; images are
    opened and run through ``processor.process_image`` on access.

    Args:
        path: Location of the delimited file.
        image_key: Column holding the (relative) image paths.
        caption_key: Column holding the caption strings.
        delimiter: Field separator handed to ``pandas.read_csv``.
        processor: Object exposing ``process_text`` and ``process_image``.
        inmodal: When true, an augmented copy of every caption is processed
            as well and items carry ``(original, augmented)`` pairs.
        defense: When true, crop/resize transforms are created and stored
            (they are not applied anywhere inside this class).
        crop_size: Side length of the random crop built when ``defense`` is set.
    """

    def __init__(self, path, image_key, caption_key, delimiter, processor, inmodal = False, defense = False, crop_size = 150):
        logging.debug(f"Loading aligned data from {path}")

        frame = pd.read_csv(path, sep = delimiter).dropna()

        self.root = os.path.dirname(path)
        self.images = frame[image_key].tolist()
        self.captions_text = frame[caption_key].tolist()
        self.captions = processor.process_text(self.captions_text)
        self.processor = processor

        self.inmodal = inmodal
        if inmodal:
            augmented = list(map(_augment_text, frame[caption_key].tolist()))
            self.augment_captions = processor.process_text(augmented)

        self.defense = defense
        if defense:
            self.crop_transform = transforms.RandomCrop((crop_size, crop_size))
            self.resize_transform = transforms.Resize((224, 224))

        # Explicit per-row labels are optional; without them the label is
        # derived from the image path on access.
        self.is_backdoor = frame['is_backdoor'].tolist() if 'is_backdoor' in frame else None

        logging.debug("Loaded data")

    def __len__(self):
        """Return the number of image/caption rows kept after ``dropna``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict.

        Keys: ``image_path``, ``is_backdoor``, ``caption``, ``input_ids``,
        ``attention_mask`` and ``pixel_values``. With ``inmodal`` enabled the
        last three are 2-tuples of ``(original, augmented)``.
        """
        relative_path = self.images[idx]
        absolute_path = os.path.join(self.root, relative_path)

        item = {"image_path": relative_path}
        image = Image.open(absolute_path)

        if self.is_backdoor:
            item["is_backdoor"] = self.is_backdoor[idx]
        else:
            # No label column: treat any path containing 'backdoor' as poisoned.
            item["is_backdoor"] = 'backdoor' in relative_path
        item["caption"] = self.captions_text[idx]

        ids = self.captions["input_ids"][idx]
        mask = self.captions["attention_mask"][idx]
        if not self.inmodal:
            item["input_ids"] = ids
            item["attention_mask"] = mask
            item["pixel_values"] = self.processor.process_image(image)
            return item

        item["input_ids"] = (ids, self.augment_captions["input_ids"][idx])
        item["attention_mask"] = (mask, self.augment_captions["attention_mask"][idx])
        original_pixels = self.processor.process_image(image)
        augmented_pixels = self.processor.process_image(_augment_image(absolute_path))
        item["pixel_values"] = (original_pixels, augmented_pixels)
        return item

class Entity:
    """A graph node: its head word, its full text span and its modifiers."""

    __slots__ = ('head', 'span', 'modifiers')

    def __init__(self, name, span, modifiers):
        # The ``name`` argument is stored in the ``head`` slot.
        self.head, self.span, self.modifiers = name, span, modifiers


class Relation:
    """A directed edge ``subject --verb--> obj`` between two entities."""

    __slots__ = ('subject', 'verb', 'obj')

    def __init__(self, subject, verb, obj):
        self.subject, self.verb, self.obj = subject, verb, obj

class Subgraph:
    """One subject/relation/object triple, tagged as positive or negative."""

    __slots__ = ('subject', 'relations','object','is_positive')

    def __init__(self, subject, relation, object, is_positive=True):
        # ``subject`` and ``object`` are Entity instances; ``relation`` is a
        # Relation instance and is stored in the ``relations`` slot.
        self.subject, self.relations, self.object = subject, relation, object
        self.is_positive = is_positive

# `entities` and `relations` are two lists that hold Entity and Relation objects, respectively.
def attribute_transformation(data, num_subgraphs):
    """Build negative sub-captions by replacing every entity's attributes.

    One adjective is drawn at random from a fixed list of negative adjectives
    and used for every entity: each existing modifier is replaced by it (an
    entity without modifiers receives exactly one) and the entity's span
    becomes ``"<adjective> <head>"``.

    The input ``data`` is left untouched; new modifier dicts and spans are
    built instead of being written back into the parse result.

    Args:
        data: Parsed graph dict. ``data['entities']`` is read as a list of
            dicts with a ``'head'`` and an optional ``'modifiers'`` list;
            ``data['relations']`` as a list of dicts with ``'subject'`` and
            ``'object'`` indices into the entity list and a ``'relation'``.
        num_subgraphs: Upper bound on the number of triples returned.

    Returns:
        ``(neg_subgraphs, neg_subtexts)``: a list of ``Subgraph`` objects
        flagged ``is_positive=False`` and the matching list of
        ``"<subject span> <verb> <object span>"`` strings.
    """
    negative_adjectives = [
        "abhorrent", "atrocious", "appalling", "awful", "bad", "banal", "barbaric", "belligerent", "bitter", "boring",
        "brutal", "callous", "chaotic", "clumsy", "cold", "corrupt", "crazy", "creepy", "cruel", "cynical", "dangerous",
        "deceitful", "defective", "defiant", "delirious", "deplorable", "depressing", "desperate", "destructive",
        "devious",
        "dirty", "disgusting", "disrespectful", "disturbing", "dreadful", "dreary", "dull", "egocentric",
        "embarrassing",
        "envious", "erratic", "evil", "exhausting", "fake", "fanatical", "fierce", "filthy", "foolish", "frustrating",
        "ghastly", "greedy", "grim", "gruesome", "guilty", "hateful", "heartless", "hideous", "hopeless", "horrible",
        "hostile", "ignorant", "ill", "immature", "impractical", "impudent", "inactive", "incompetent", "inconsiderate",
        "inconsistent", "indecisive", "indifferent", "ineffective", "infernal", "insane", "insecure", "insidious",
        "insolent", "irrational", "irresponsible", "irritating", "jealous", "lazy", "lousy", "malicious", "malignant",
        "mean", "miserable", "monstrous", "moody", "nasty", "naughty", "negative", "nervous", "nonsense", "obnoxious",
        "odd", "offensive", "oppressive", "pathetic", "perverse", "petty", "poor", "provocative", "puzzling", "rancid",
        "repugnant", "repulsive", "rude", "sad", "selfish", "shameful", "sinister", "sneaky", "spiteful", "stupid",
        "suspicious", "tense", "terrible", "thoughtless", "threatening", "ugly", "unpleasant", "upset", "vicious",
        "vile", "vindictive", "wicked", "wild", "worthless", "wretched"
    ]
    bad_attribute = random.choice(negative_adjectives)

    entities = []
    for entity_data in data.get('entities', []):
        # Replace modifiers one-for-one; an unmodified entity gets a single one.
        modifier_count = max(1, len(entity_data.get('modifiers') or []))
        # Fresh dict per slot so entities never share (or leak) mutable state.
        modifiers = [
            {'dep': 'amod', 'lemma_span': bad_attribute, 'span': bad_attribute}
            for _ in range(modifier_count)
        ]
        head = entity_data['head']
        entities.append(Entity(head, bad_attribute + ' ' + head, modifiers))

    relation_dicts = data.get('relations', [])
    if relation_dicts:
        relations = [
            Relation(entities[rel['subject']], rel['relation'], entities[rel['object']])
            for rel in relation_dicts
        ]
    elif len(entities) > 1:
        # No parsed relations: chain neighbouring entities with 'and'.
        relations = [Relation(left, 'and', right) for left, right in zip(entities, entities[1:])]
    elif entities:
        relations = [Relation(entities[0], 'and', entities[0])]
    else:
        # Nothing was parsed at all: emit a single placeholder triple.
        entity_none = Entity('None', 'None', 'None')
        relations = [Relation(entity_none, 'and', entity_none)]

    num_subgraphs = min(num_subgraphs, len(relations))
    selected = relations[:num_subgraphs]

    neg_subgraphs = [Subgraph(rel.subject, rel, rel.obj, is_positive=False) for rel in selected]
    neg_subtexts = [f"{rel.subject.span} {rel.verb} {rel.obj.span}" for rel in selected]
    return neg_subgraphs, neg_subtexts


def relation_transformation(data, num_subgraphs):
    """Build negative sub-captions by replacing the verb of every relation.

    One verb is drawn at random from a fixed list of negative verbs and used
    for every triple. Entities are kept as parsed. The returned ``Subgraph``
    objects carry a ``Relation`` whose ``verb`` is the substituted verb, so
    the graph and the generated text describe the same (negative) triple.

    Args:
        data: Parsed graph dict. ``data['entities']`` is read as a list of
            dicts with ``'head'``, ``'span'`` and an optional ``'modifiers'``;
            ``data['relations']`` as a list of dicts with ``'subject'`` and
            ``'object'`` indices into the entity list.
        num_subgraphs: Upper bound on the number of triples returned.

    Returns:
        ``(neg_subgraphs, neg_subtexts)``: a list of ``Subgraph`` objects
        flagged ``is_positive=False`` and the matching list of
        ``"<subject span> <negative verb> <object span>"`` strings.
    """
    # NOTE: "ridicule" appears twice, which doubles its sampling weight.
    negative_verbs = [
        "abandon", "abuse", "accuse", "attack", "avoid", "badmouth", "begrudge", "belittle", "betray", "blame",
        "bother", "bully", "cheat", "complicate", "condemn", "confuse", "criticize", "crush", "curse", "damage",
        "deceive", "defame", "defy", "demoralize", "deny", "deprive", "destroy", "disappoint", "disbelieve",
        "discriminate",
        "dislike", "disobey", "disparage", "disturb", "embarrass", "exploit", "fear", "forget", "frustrate", "harass",
        "hate", "humiliate", "ignore", "insult", "invalidate", "jeopardize", "judge", "lament", "lie", "manipulate",
        "mislead", "mistreat", "mock", "neglect", "offend", "oppose", "persecute", "pity", "ridicule", "reject",
        "reproach", "resent", "ridicule", "sabotage", "scare", "scold", "shame", "shun", "slander", "slight",
        "smother", "stifle", "suffer", "suppress", "taint", "taunt", "tease", "threaten", "torment", "torture",
        "undermine", "unnerve", "upset", "use", "victimize", "violate", "weaken", "worry", "wound"
    ]
    bad_relation = random.choice(negative_verbs)

    entities = [
        Entity(entity_data['head'], entity_data['span'], entity_data.get('modifiers', []))
        for entity_data in data.get('entities', [])
    ]

    # Collect (subject, object) pairs; the verb is always the substituted one.
    relation_dicts = data.get('relations', [])
    if relation_dicts:
        pairs = [(entities[rel['subject']], entities[rel['object']]) for rel in relation_dicts]
    elif len(entities) > 1:
        # No parsed relations: pair up neighbouring entities.
        pairs = list(zip(entities, entities[1:]))
    elif entities:
        pairs = [(entities[0], entities[0])]
    else:
        # Nothing was parsed at all: emit a single placeholder triple.
        entity_none = Entity('None', 'None', 'None')
        pairs = [(entity_none, entity_none)]

    num_subgraphs = min(num_subgraphs, len(pairs))

    neg_subgraphs = []
    neg_subtexts = []
    for subject, obj in pairs[:num_subgraphs]:
        relation = Relation(subject, bad_relation, obj)
        neg_subgraphs.append(Subgraph(subject, relation, obj, is_positive=False))
        neg_subtexts.append(f"{subject.span} {bad_relation} {obj.span}")

    return neg_subgraphs, neg_subtexts


def entity_transformation(data, num_subgraphs):
    """Build negative sub-captions by replacing the head of every entity.

    One noun is drawn at random from a fixed list and substituted for the
    head of every entity. An entity that has modifiers keeps the span of its
    first modifier as a prefix (``"<modifier> <noun>"``); otherwise its span
    is just the noun.

    The input ``data`` is left untouched; replacement heads and spans are
    placed on new ``Entity`` objects instead of being written back into the
    parse result.

    Args:
        data: Parsed graph dict. ``data['entities']`` is read as a list of
            dicts with an optional ``'modifiers'`` list (each modifier having
            a ``'span'``); ``data['relations']`` as a list of dicts with
            ``'subject'`` and ``'object'`` indices and a ``'relation'``.
        num_subgraphs: Upper bound on the number of triples returned.

    Returns:
        ``(neg_subgraphs, neg_subtexts)``: a list of ``Subgraph`` objects
        flagged ``is_positive=False`` and the matching list of
        ``"<subject span> <verb> <object span>"`` strings.
    """
    # NOTE: several nouns are listed more than once, which raises their
    # sampling weight.
    neg_entities = [
        "apple", "banana", "car", "dog", "elephant", "flower", "guitar", "hat", "ice cream", "jacket",
        "kite", "lion", "mountain", "notebook", "ocean", "piano", "queen", "rose", "sun", "tree",
        "umbrella", "violin", "watermelon", "xylophone", "yacht", "zebra", "airplane", "bicycle", "cat", "desk",
        "eagle", "fireworks", "globe", "hammer", "island", "jungle", "kangaroo", "lamp", "mango", "necklace",
        "octopus", "penguin", "quilt", "rainbow", "sweater", "tiger", "unicorn", "vase", "waterfall", "xylophone",
        "yoga mat", "zeppelin", "alarm clock", "bookshelf", "candle", "drum", "easel", "fountain", "giraffe", "honey",
        "igloo", "jigsaw puzzle", "kite", "lighthouse", "mailbox", "net", "oar", "parrot", "quilt", "rake",
        "sailboat", "telescope", "umbrella", "violin", "wagon", "xylophone", "yacht", "zebra", "airplane", "bicycle",
        "cat", "desk", "eagle", "fireworks", "globe", "hammer", "island", "jungle", "kangaroo", "lamp"
    ]
    bad_entity = random.choice(neg_entities)

    entities = []
    for entity_data in data.get('entities', []):
        # Copy so the returned entities do not alias the caller's lists.
        modifiers = list(entity_data.get('modifiers') or [])
        if modifiers:
            span = modifiers[0]['span'] + ' ' + bad_entity
        else:
            span = bad_entity
        entities.append(Entity(bad_entity, span, modifiers))

    relation_dicts = data.get('relations', [])
    if relation_dicts:
        relations = [
            Relation(entities[rel['subject']], rel['relation'], entities[rel['object']])
            for rel in relation_dicts
        ]
    elif len(entities) > 1:
        # No parsed relations: chain neighbouring entities with 'and'.
        relations = [Relation(left, 'and', right) for left, right in zip(entities, entities[1:])]
    elif entities:
        relations = [Relation(entities[0], 'and', entities[0])]
    else:
        # Nothing was parsed at all: emit a single placeholder triple.
        entity_none = Entity('None', 'None', 'None')
        relations = [Relation(entity_none, 'and', entity_none)]

    num_subgraphs = min(num_subgraphs, len(relations))
    selected = relations[:num_subgraphs]

    neg_subgraphs = [Subgraph(rel.subject, rel, rel.obj, is_positive=False) for rel in selected]
    neg_subtexts = [f"{rel.subject.span} {rel.verb} {rel.obj.span}" for rel in selected]
    return neg_subgraphs, neg_subtexts

def pos_subgraph_generate(data, num_subgraphs):
    """Turn a parsed graph into positive triples and their text form.

    Entities and relations are taken from ``data`` unchanged. If ``data``
    has no relations, neighbouring entities are joined with ``'and'``; a
    single entity is joined with itself, and an empty graph yields one
    placeholder triple built from ``'None'`` strings.

    Args:
        data: Parsed graph dict; ``data['entities']`` and
            ``data['relations']`` are read.
        num_subgraphs: Upper bound on the number of triples returned.

    Returns:
        ``(pos_subgraphs, pos_subtexts)``: ``Subgraph`` objects flagged
        ``is_positive=True`` and the matching
        ``"<subject span> <verb> <object span>"`` strings.
    """
    nodes = [
        Entity(node['head'], node['span'], node.get('modifiers', []))
        for node in data.get('entities', [])
    ]

    parsed_relations = data.get('relations', [])
    if parsed_relations:
        edges = [
            Relation(nodes[rel['subject']], rel['relation'], nodes[rel['object']])
            for rel in parsed_relations
        ]
    elif len(nodes) > 1:
        edges = [Relation(nodes[k], 'and', nodes[k + 1]) for k in range(len(nodes) - 1)]
    elif len(nodes) == 1:
        edges = [Relation(nodes[0], 'and', nodes[0])]
    else:
        placeholder = Entity('None','None','None')
        edges = [Relation(placeholder, 'and', placeholder)]

    num_subgraphs = min(num_subgraphs, len(edges))
    if num_subgraphs == 0:
        print("num==0:",data.get('relations',[]))

    pos_subgraphs = []
    pos_subtexts = []
    for edge in edges[:num_subgraphs]:
        pos_subgraphs.append(Subgraph(edge.subject, edge, edge.obj, is_positive=True))
        pos_subtexts.append(f"{edge.subject.span} {edge.verb} {edge.obj.span}")

    return pos_subgraphs, pos_subtexts

def neg_subgraph_generate(data, num_subgraphs, transformation='attribute'):
    """Produce negative triples using the named transformation.

    Args:
        data: Parsed graph dict, forwarded to the selected transformation.
        num_subgraphs: Upper bound on the number of triples returned.
        transformation: ``'attribute'``, ``'relation'`` or ``'entity'``.

    Returns:
        ``(neg_subgraphs, neg_subtexts)`` as returned by the selected
        transformation function.

    Raises:
        ValueError: If ``transformation`` is not one of the supported names.
    """
    if transformation not in ('attribute', 'relation', 'entity'):
        raise ValueError(f"Unsupported transformation: {transformation}")

    if transformation == 'attribute':
        transform = attribute_transformation
    elif transformation == 'relation':
        transform = relation_transformation
    else:
        transform = entity_transformation

    neg_subgraphs, neg_subtexts = transform(data, num_subgraphs)
    return neg_subgraphs, neg_subtexts

class TAC_ImageCaptionDataset(Dataset):
    """Subset of an image/caption file whose items also carry sub-captions.

    The delimited file at ``path`` is loaded, rows with missing values are
    dropped, and only the rows selected by ``TAC_idx`` are kept. On access,
    the caption is parsed with ``sng_parser`` and positive / negative
    sub-captions are generated from the resulting graph and tokenised with
    ``processor.process_text``.

    Args:
        options: Object providing ``num_pos_graphs``, ``num_neg_graphs`` and
            ``TAC_neg_mode``.
        path: Location of the delimited file.
        image_key: Column holding the (relative) image paths.
        caption_key: Column holding the caption strings.
        delimiter: Field separator handed to ``pandas.read_csv``.
        processor: Object exposing ``process_text`` and ``process_image``.
        TAC_idx: Row positions (after ``dropna``) to keep.
        inmodal: When true, augmented captions/images are produced as well.
        defense: When true, crop/resize transforms are created and stored
            (they are not applied anywhere inside this class).
        crop_size: Side length of the random crop built when ``defense`` is set.
    """

    def __init__(self,options, path, image_key, caption_key, delimiter, processor, TAC_idx, inmodal=False, defense=False, crop_size=150):
        logging.debug(f"Loading aligned data from {path}")

        df = pd.read_csv(path, sep=delimiter)
        df = df.dropna()

        self.root = os.path.dirname(path)
        self.images = df[image_key].tolist()
        self.captions_text = df[caption_key].tolist()
        # NOTE(review): this tokenises the full caption list before it is
        # narrowed to TAC_idx below, and __getitem__ never reads
        # self.captions (it re-tokenises per item) — confirm it is needed.
        self.captions = processor.process_text(self.captions_text)

        self.processor = processor
        self.TAC_idx = TAC_idx
        self.inmodal = inmodal
        self.num_pos_graphs = options.num_pos_graphs
        self.num_neg_graphs = options.num_neg_graphs
        if (inmodal):
            # NOTE(review): like self.captions, this covers the full file and
            # is not read by __getitem__ — confirm it is needed.
            self.augment_captions = processor.process_text(
                [_augment_text(caption) for caption in df[caption_key].tolist()])

        self.defense = defense
        self.TAC_neg_mode = options.TAC_neg_mode

        if self.defense:
            self.crop_transform = transforms.RandomCrop((crop_size, crop_size))
            self.resize_transform = transforms.Resize((224, 224))

        if 'is_backdoor' in df:
            self.is_backdoor = df['is_backdoor'].tolist()
        else:
            self.is_backdoor = None


        # Keep only the selected rows; TAC_idx indexes the post-dropna lists,
        # so every index must be smaller than the number of remaining rows.
        self.images = [self.images[i] for i in TAC_idx]
        self.captions_text = [self.captions_text[i] for i in TAC_idx]



        if 'is_backdoor' in df:
            self.is_backdoor = [self.is_backdoor[i] for i in TAC_idx]



        logging.debug("Loaded data")

    def __len__(self):
        """Return the number of selected rows."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict.

        Besides ``image_path`` and ``is_backdoor`` the dict holds the raw
        ``process_text`` results (``captions``, ``pos_subcaptions``,
        ``neg_subcaptions`` and, with ``inmodal``, their ``augment_*``
        counterparts) and the fields ``input_ids``, ``attention_mask``,
        ``pixel_values`` plus ``pos_*`` / ``neg_*`` variants of each. With
        ``inmodal`` enabled those fields are 2-tuples of
        ``(original, augmented)``.
        """
        item = {}


        item["image_path"] = self.images[idx]
        image = Image.open(os.path.join(self.root, self.images[idx]))
        # Without explicit labels, any path containing 'backdoor' counts as poisoned.
        item["is_backdoor"] = 'backdoor' in self.images[idx] if not self.is_backdoor else self.is_backdoor[idx]
        caption_text = self.captions_text[idx]
        # Parse the caption once; both generators below receive the same graph
        # object (the negative generator runs second).
        caption_graph = sng_parser.parse(caption_text)
        TAC_pos_subgraphs,TAC_pos_texts= pos_subgraph_generate(caption_graph,
                                                                               self.num_pos_graphs)
        TAC_neg_subgraphs, TAC_neg_texts = neg_subgraph_generate(caption_graph,
                                                                               self.num_neg_graphs,transformation=self.TAC_neg_mode)

        item["captions"] = self.processor.process_text(caption_text)
        item["pos_subcaptions"] = self.processor.process_text(TAC_pos_texts)
        item["neg_subcaptions"] = self.processor.process_text(TAC_neg_texts)
        # Only input_ids are squeezed here; the attention masks are used as returned.
        # NOTE(review): squeeze() with no dim also drops the sub-caption axis
        # when exactly one sub-caption was generated — confirm downstream
        # code expects that.
        item["captions"]["input_ids"] = item["captions"]["input_ids"].squeeze()
        item["pos_subcaptions"]["input_ids"] = item["pos_subcaptions"]["input_ids"].squeeze()
        item["neg_subcaptions"]["input_ids"] = item["neg_subcaptions"]["input_ids"].squeeze()


        if (self.inmodal):

            item["augment_captions"] = self.processor.process_text(_augment_text(caption_text))
            # NOTE(review): _augment_text receives a list of strings here,
            # whereas __init__ calls it on one caption string at a time —
            # confirm it accepts lists.
            item["augment_pos_subcaptions"] = self.processor.process_text(_augment_text(TAC_pos_texts))
            item["augment_neg_subcaptions"] = self.processor.process_text(_augment_text(TAC_neg_texts))



            item["input_ids"] = item["captions"]["input_ids"], item["augment_captions"]["input_ids"].squeeze()
            item["pos_sub_input_ids"] = item["pos_subcaptions"]["input_ids"],item["augment_pos_subcaptions"]["input_ids"].squeeze()
            item["neg_sub_input_ids"] = item["neg_subcaptions"]["input_ids"], item["augment_neg_subcaptions"]["input_ids"].squeeze()
            item["attention_mask"] = item["captions"]["attention_mask"], \
                                     item["augment_captions"]["attention_mask"]
            item["pos_sub_attention_masks"] = item["pos_subcaptions"]["attention_mask"],\
                                              item["augment_pos_subcaptions"]["attention_mask"]
            item["neg_sub_attention_masks"] = item["neg_subcaptions"]["attention_mask"], \
                                              item["augment_neg_subcaptions"]["attention_mask"]
            item["pixel_values"] = self.processor.process_image(image), self.processor.process_image(
                _augment_image(os.path.join(self.root, self.images[idx])))

            # NOTE(review): item["pos_subcaptions"] is the process_text result
            # that is indexed by string keys above, so len() is presumably the
            # number of keys rather than the number of sub-captions; and if
            # process_image returns a tensor, `*` scales its values instead of
            # repeating the image — confirm the intended behaviour.
            item["pos_pixel_values"] = item["pixel_values"][0] * len(item["pos_subcaptions"]),\
                                       item["pixel_values"][1] * len(item["pos_subcaptions"])
            item["neg_pixel_values"] = item["pixel_values"][0] * len(item["neg_subcaptions"]),\
                                       item["pixel_values"][1] * len(item["neg_subcaptions"])

        else:
            item["input_ids"] = item["captions"]["input_ids"]
            item["pos_sub_input_ids"] = item["pos_subcaptions"]["input_ids"]
            item["neg_sub_input_ids"] = item["neg_subcaptions"]["input_ids"]
            item["attention_mask"] = item["captions"]["attention_mask"]
            item["pos_sub_attention_masks"] = item["pos_subcaptions"]["attention_mask"]
            item["neg_sub_attention_masks"] = item["neg_subcaptions"]["attention_mask"]
            item["pixel_values"] = self.processor.process_image(image)

            # NOTE(review): same concern as in the inmodal branch — len() of
            # the process_text result and `*` on the pixel values; confirm.
            item["pos_pixel_values"] = item["pixel_values"] * len(item["pos_subcaptions"])
            item["neg_pixel_values"] = item["pixel_values"] * len(item["neg_subcaptions"])
        # print("pos_sub_input_ids:", item["pos_sub_input_ids"].shape,
        #       "neg_sub_input_ids:", item["neg_sub_input_ids"].shape, "input_ids:", item["input_ids"].shape)
        # print("pos_sub_mask:", item["pos_sub_attention_masks"],
        #       "neg_sub_mask:", item["neg_sub_attention_masks"], "mask:", item["attention_mask"])
        # print("pixel values:",item["pixel_values"].shape,"pos pixel values:",item["pos_pixel_values"].shape,"neg pixel values:",item["neg_pixel_values"].shape)


        return item



def TAC_collate_fn(batch):
    """Collate per-sample dicts into one dict of stacked tensors.

    ``None`` entries in ``batch`` are skipped. For each of the nine expected
    keys the per-sample values are converted with ``torch.as_tensor`` and
    stacked along ``dim=1``, so the sample index becomes the second axis of
    every output tensor.

    Args:
        batch: Sequence of sample dicts (or ``None``) that each provide
            ``input_ids``, ``attention_mask``, ``pixel_values`` and their
            ``pos_*`` / ``neg_*`` counterparts. For a given key, all samples
            must have the same shape.

    Returns:
        Dict mapping each key to the stacked tensor.

    Raises:
        RuntimeError: From ``torch.stack`` if no sample remains after
            filtering or if shapes differ between samples.
    """
    keys = (
        "input_ids",
        "pos_sub_input_ids",
        "neg_sub_input_ids",
        "attention_mask",
        "pos_sub_attention_masks",
        "neg_sub_attention_masks",
        "pixel_values",
        "pos_pixel_values",
        "neg_pixel_values",
    )

    # Filter out any empty samples
    samples = [item for item in batch if item is not None]

    # Driving every key through the same expression guarantees each output is
    # built from its own inputs (previously "neg_pixel_values" was stacked
    # from the positive pixel values).
    return {
        key: torch.stack([torch.as_tensor(sample[key]) for sample in samples], dim=1)
        for key in keys
    }



class OptimizePatchDataset(Dataset):
    """Image/caption pairs whose images are returned as plain tensors.

    Loading mirrors a regular caption dataset (delimited file, rows with
    missing values dropped, captions run through ``processor.process_text``
    up front), but pixel values come from ``transforms.ToTensor`` applied to
    the RGB-converted image instead of ``processor.process_image``.

    Args:
        path: Location of the delimited file.
        image_key: Column holding the (relative) image paths.
        caption_key: Column holding the caption strings.
        delimiter: Field separator handed to ``pandas.read_csv``.
        processor: Object exposing ``process_text``.
        inmodal: When true, an augmented copy of every caption is processed
            as well and items carry ``(original, augmented)`` pairs.
        defense: When true, crop/resize transforms are created and stored
            (they are not applied anywhere inside this class).
        crop_size: Side length of the random crop built when ``defense`` is set.
    """

    def __init__(self, path, image_key, caption_key, delimiter, processor, inmodal = False, defense = False, crop_size = 150):
        logging.debug(f"Loading aligned data from {path}")

        frame = pd.read_csv(path, sep = delimiter).dropna()

        self.root = os.path.dirname(path)
        self.images = frame[image_key].tolist()
        self.captions_text = frame[caption_key].tolist()
        self.captions = processor.process_text(self.captions_text)
        self.processor = processor

        self.transform = transforms.Compose([transforms.ToTensor()])

        self.inmodal = inmodal
        if inmodal:
            augmented = list(map(_augment_text, frame[caption_key].tolist()))
            self.augment_captions = processor.process_text(augmented)

        self.defense = defense
        if defense:
            self.crop_transform = transforms.RandomCrop((crop_size, crop_size))
            self.resize_transform = transforms.Resize((224, 224))

        # Explicit per-row labels are optional; without them the label is
        # derived from the image path on access.
        self.is_backdoor = frame['is_backdoor'].tolist() if 'is_backdoor' in frame else None

        logging.debug("Loaded data")

    def __len__(self):
        """Return the number of image/caption rows kept after ``dropna``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict.

        Keys: ``image_path``, ``is_backdoor``, ``caption``, ``input_ids``,
        ``attention_mask`` and ``pixel_values``. With ``inmodal`` enabled the
        last three are 2-tuples of ``(original, augmented)``.
        """
        relative_path = self.images[idx]
        absolute_path = os.path.join(self.root, relative_path)

        item = {"image_path": relative_path}
        image = Image.open(absolute_path)

        if self.is_backdoor:
            item["is_backdoor"] = self.is_backdoor[idx]
        else:
            # No label column: treat any path containing 'backdoor' as poisoned.
            item["is_backdoor"] = 'backdoor' in relative_path
        item["caption"] = self.captions_text[idx]

        ids = self.captions["input_ids"][idx]
        mask = self.captions["attention_mask"][idx]
        if not self.inmodal:
            item["input_ids"] = ids
            item["attention_mask"] = mask
            item["pixel_values"] = self.transform(image.convert("RGB"))
            return item

        item["input_ids"] = (ids, self.augment_captions["input_ids"][idx])
        item["attention_mask"] = (mask, self.augment_captions["attention_mask"][idx])
        original_pixels = self.transform(image.convert("RGB"))
        augmented_pixels = self.transform(_augment_image(absolute_path).convert("RGB"))
        item["pixel_values"] = (original_pixels, augmented_pixels)
        return item

def calculate_scores(options, model, dataloader, epoch):
    """Score every image/caption pair and append the results to a CSV file.

    For each batch the model is run without gradients and the score of a
    pair is ``exp(logit_scale)`` times the dot product of its image and text
    embeddings. One row ``[image_path, caption, is_backdoor, score]`` is
    written per sample.

    The file lives next to ``options.train_data`` and is named
    ``{options.name}_{epoch}.csv``. It is opened in append mode, so calling
    this twice for the same name and epoch adds rows to the existing file.
    The file is closed (and therefore flushed) before the path is returned.

    Args:
        options: Provides ``distributed``, ``train_data``, ``name`` and
            ``device``.
        model: Model to evaluate; when ``options.distributed`` is set, its
            ``.module`` attribute is used instead.
        dataloader: Yields batches with ``pixel_values``, ``input_ids``,
            ``attention_mask``, ``image_path``, ``caption`` and
            ``is_backdoor``.
        epoch: Value interpolated into the output file name.

    Returns:
        Path of the CSV file that was written.
    """
    if options.distributed:
        model = model.module
    model.eval()

    dirname = os.path.dirname(options.train_data)
    filename = f'{options.name}_{epoch}.csv'
    path = os.path.join(dirname, filename)

    # newline='' is what the csv module expects for files it writes to.
    with open(path, 'a', newline='') as csvfile, torch.no_grad():
        csvwriter = csv.writer(csvfile)
        logging.info(len(dataloader))
        for batch in tqdm(dataloader):
            image = batch["pixel_values"].to(options.device)
            input_ids = batch["input_ids"].to(options.device)
            attention_mask = batch["attention_mask"].to(options.device)
            outputs = model(input_ids = input_ids, attention_mask = attention_mask, pixel_values = image)
            # Only the diagonal is needed: sample i's image against sample i's caption.
            scores = model.logit_scale.exp() * torch.diagonal(outputs.image_embeds @ outputs.text_embeds.t())
            csvwriter.writerows(zip(
                batch['image_path'],
                batch['caption'],
                batch['is_backdoor'].tolist(),
                scores.tolist(),
            ))
    return path

def get_clean_train_dataloader(options, processor, path):
    """Filter a scored CSV in place and build a training dataloader from it.

    On the master process the header-less file at ``path`` (columns image,
    caption, is_backdoor, score) is sorted by descending score, the top
    ``options.remove_fraction`` of rows is discarded, and the remainder is
    written back to ``path`` with a header row. Backdoor counts (total and
    removed) are logged to wandb when ``options.wandb`` is set.

    Every process then loads ``path`` as an ``ImageCaptionDataset``.

    Args:
        options: Provides ``master``, ``remove_fraction``, ``wandb``,
            ``image_key``, ``caption_key``, ``delimiter``, ``distributed``,
            ``batch_size`` and ``num_workers``.
        processor: Forwarded to ``ImageCaptionDataset``.
        path: Scored CSV file; overwritten on the master process.

    Returns:
        A ``DataLoader`` annotated with ``num_samples`` and ``num_batches``.
    """
    logging.info(f'Creating a clean train dataloader with path {path}')

    if options.master:
        columns = ['image', 'caption', 'is_backdoor', 'score']
        scored = pd.read_csv(path, names = columns, header = None)
        scored = scored.sort_values(by=['score'], ascending = False)

        # Highest-scoring rows are the ones dropped.
        cutoff = int(options.remove_fraction * len(scored))
        removed, kept = scored.iloc[:cutoff], scored.iloc[cutoff:]

        total_backdoors = sum(scored['is_backdoor'].tolist())
        backdoor_detected = sum(removed['is_backdoor'].tolist())
        if options.wandb:
            wandb.log({
                'number of backdoored images': total_backdoors,
                'number of backdoor images removed': backdoor_detected,
            })
        kept.to_csv(path, index = False)

    dataset = ImageCaptionDataset(
        path,
        image_key = options.image_key,
        caption_key = options.caption_key,
        delimiter = options.delimiter,
        processor = processor,
    )
    sampler = None
    if options.distributed:
        sampler = DistributedSampler(dataset)

    dataloader = DataLoader(
        dataset,
        batch_size = options.batch_size,
        shuffle = (sampler is None),
        num_workers = options.num_workers,
        pin_memory = True,
        sampler = sampler,
        drop_last = True,
    )
    batch_count = len(dataloader)
    dataloader.num_samples = batch_count * options.batch_size
    dataloader.num_batches = batch_count
    return dataloader
    
def get_train_dataloader(options, processor):
    """Build the training dataloader from ``options.train_data``.

    Args:
        options: Provides ``train_data``, ``batch_size``, ``image_key``,
            ``caption_key``, ``delimiter``, ``inmodal``, ``distributed`` and
            ``num_workers``.
        processor: Forwarded to ``ImageCaptionDataset``.

    Returns:
        ``None`` when ``options.train_data`` is ``None``; otherwise a
        ``DataLoader`` annotated with ``num_samples`` and ``num_batches``.
    """
    path = options.train_data
    if path is None:
        return None

    batch_size = options.batch_size

    dataset = ImageCaptionDataset(
        path,
        image_key = options.image_key,
        caption_key = options.caption_key,
        delimiter = options.delimiter,
        processor = processor,
        inmodal = options.inmodal,
    )

    sampler = None
    if options.distributed:
        sampler = DistributedSampler(dataset)

    # Shuffle here only when no sampler is in charge of ordering.
    dataloader = DataLoader(
        dataset,
        batch_size = batch_size,
        shuffle = (sampler is None),
        num_workers = options.num_workers,
        pin_memory = True,
        sampler = sampler,
        drop_last = True,
    )
    batch_count = len(dataloader)
    dataloader.num_samples = batch_count * batch_size
    dataloader.num_batches = batch_count

    return dataloader

def get_TAC_train_dataloaer(options,processor):
    """Build the TAC training DataLoader over ``options.train_data``.

    ``options.num_TAC_samples`` distinct indices are drawn from
    ``range(options.num_train_data)`` with ``random.sample`` and handed to
    ``TAC_ImageCaptionDataset``; the same list is exposed on the returned
    loader as ``TAC_idx``. Returns ``None`` when no training CSV is configured.
    The loader also carries ``num_batches`` and ``num_samples``
    (``num_batches * options.batch_size``).
    """
    csv_path = options.train_data
    if csv_path is None:
        return None

    per_step = options.batch_size
    chosen_indices = random.sample(range(options.num_train_data), options.num_TAC_samples)

    tac_set = TAC_ImageCaptionDataset(
        options,
        csv_path,
        image_key=options.image_key,
        caption_key=options.caption_key,
        delimiter=options.delimiter,
        processor=processor,
        TAC_idx=chosen_indices,
        inmodal=options.inmodal,
    )

    # DataLoader-level shuffling is only requested when no sampler is attached.
    if options.distributed:
        dist_sampler = DistributedSampler(tac_set)
    else:
        dist_sampler = None

    loader = DataLoader(
        tac_set,
        batch_size=per_step,
        shuffle=dist_sampler is None,
        num_workers=options.num_workers,
        pin_memory=True,
        sampler=dist_sampler,
        drop_last=True,
    )
    loader.num_batches = len(loader)
    loader.num_samples = loader.num_batches * per_step
    loader.TAC_idx = chosen_indices
    return loader




def get_patch_train_dataloader(options, processor):
    """Build the DataLoader over ``options.train_patch_data`` using ``OptimizePatchDataset``.

    Returns ``None`` when ``options.train_patch_data`` is ``None``. The loader
    drops the last incomplete batch and is annotated with ``num_batches`` and
    ``num_samples`` (``num_batches * options.batch_size``).
    """
    csv_path = options.train_patch_data
    if csv_path is None:
        return None

    per_step = options.batch_size

    patch_set = OptimizePatchDataset(
        csv_path,
        image_key=options.image_key,
        caption_key=options.caption_key,
        delimiter=options.delimiter,
        processor=processor,
        inmodal=options.inmodal,
    )

    # DataLoader-level shuffling is only requested when no sampler is attached.
    if options.distributed:
        dist_sampler = DistributedSampler(patch_set)
    else:
        dist_sampler = None

    loader = DataLoader(
        patch_set,
        batch_size=per_step,
        shuffle=dist_sampler is None,
        num_workers=options.num_workers,
        pin_memory=True,
        sampler=dist_sampler,
        drop_last=True,
    )
    loader.num_batches = len(loader)
    loader.num_samples = loader.num_batches * per_step
    return loader

def get_validation_dataloader(options, processor):
    """Build the validation DataLoader over ``options.validation_data``.

    Returns ``None`` when no validation CSV is configured.

    When ``options`` has an ``encoder_usage_info`` attribute, only a random
    subset of at most 785 examples (drawn without replacement) is loaded, in
    shuffled order; otherwise the whole dataset is loaded in order.

    The returned loader carries ``num_samples`` (number of examples the loader
    actually iterates over) and ``num_batches`` (``len(loader)``).
    """
    path = options.validation_data
    if path is None:
        return None

    dataset = ImageCaptionDataset(path, image_key = options.image_key, caption_key = options.caption_key, delimiter = options.delimiter, processor = processor, inmodal = options.inmodal)

    if hasattr(options, 'encoder_usage_info'):
        # Clamp the subset size: sampling without replacement raises when
        # more items are requested than the dataset holds.
        subset_size = min(785, len(dataset))
        indices = np.random.choice(len(dataset), size = subset_size, replace = False)
        source = Subset(dataset, indices)
        shuffle_batches = True
    else:
        source = dataset
        shuffle_batches = False

    dataloader = DataLoader(source, batch_size = options.batch_size, shuffle = shuffle_batches, num_workers = options.num_workers, pin_memory = True, sampler = None, drop_last = False)
    # Report the size of what is actually iterated (the subset, when one is
    # used), so per-sample averages computed from num_samples are not skewed.
    dataloader.num_samples = len(source)
    dataloader.num_batches = len(dataloader)

    return dataloader

def count_files_in_directory(directory_path):
    """Return how many direct entries of ``directory_path`` are regular files.

    Sub-directories are not counted and are not descended into.
    """
    return sum(
        1
        for entry in os.listdir(directory_path)
        if os.path.isfile(os.path.join(directory_path, entry))
    )

class ImageLabelDataset(Dataset):
    """Image/label dataset backed by a CSV with ``image`` and ``label`` columns.

    ``options`` is optional. Without it, ``<root>/labels.csv`` is read and no
    backdoor handling takes place. With it, the following attributes are
    consulted (a missing attribute is treated like its neutral default):

    * ``eval_test_data_csv``: alternative CSV path (default: ``<root>/labels.csv``).
    * ``save_files_name``: directory name substituted for ``ILSVRC2012_val`` when
      triggered images are written out; if that directory already holds
      exactly 50000 files, images are read from it and no trigger is re-applied.
    * ``add_backdoor``: when truthy, every label is replaced by the index of
      ``options.label`` in the ImageNet class list and a trigger is applied to
      each image.
    * ``backdoor_sufi``: when truthy, 1000 indices are drawn at random from
      ``range(50000)`` and only those images receive the trigger.
    """

    def __init__(self, root, transform, options = None):
        def _opt(name, default = None):
            # ``options`` defaults to None and several callers rely on that,
            # so every lookup must tolerate a missing options object.
            return getattr(options, name, default)

        self.root = root
        self.transform = transform
        self.options = options
        self.generate_backdoor = True

        labels_csv = _opt('eval_test_data_csv')
        if labels_csv is None:
            labels_csv = os.path.join(root, 'labels.csv')
        df = pd.read_csv(labels_csv)
        self.images = df["image"]

        save_files_name = _opt('save_files_name')
        if save_files_name is not None:
            # NOTE(review): labels are taken from this second CSV while the image
            # paths above come from the first one; presumably both list the
            # same rows. Confirm.
            df = pd.read_csv(os.path.join(options.eval_test_data_dir, 'labels.csv'))
            ori_file = os.path.dirname(df["image"][0])
            save_file = ori_file.replace('ILSVRC2012_val', save_files_name)
            # The save directory is only created lazily in __getitem__, so it
            # may not exist yet on the first run.
            if os.path.isdir(save_file) and self.count_files_in_directory(save_file) == 50000:
                self.images = self.images.str.replace('ILSVRC2012_val/', save_files_name + '/', regex = False)
                self.generate_backdoor = False
        self.labels = df["label"]

        self.add_backdoor = _opt('add_backdoor', False)
        if self.add_backdoor:
            # NOTE(security): eval() executes the file's contents; this path must
            # only ever point at a trusted, project-controlled file.
            with open("data/ImageNet1K/validation/classes.py", "r") as handle:
                config = eval(handle.read())
            classes = config["classes"]
            self.poison_id = int([i for i, x in enumerate(classes) if x == options.label][0])
            # Overwrite every label with the target class on a private copy;
            # writing through ``Series.values`` is not reliable when pandas
            # hands out read-only arrays.
            self.labels = self.labels.copy()
            self.labels.iloc[:] = self.poison_id

        self.backdoor_sufi = _opt('backdoor_sufi', False)
        if self.backdoor_sufi:
            self.backdoor_indices = list(range(50000))
            shuffle(self.backdoor_indices)
            self.backdoor_indices = self.backdoor_indices[:1000]

    def __len__(self):
        """Return the number of labelled rows."""
        return len(self.labels)

    def add_trigger(self, image, patch_size = 16, patch_type = 'blended', patch_location = 'blended', tigger_pth=None, args=None):
        """Apply the backdoor trigger to ``image`` via ``apply_trigger``.

        ``args`` is accepted for call compatibility but not used: the
        dataset's own ``self.options`` is always forwarded instead.
        """
        return apply_trigger(image, patch_size, patch_type, patch_location, tigger_pth, args=self.options)

    def count_files_in_directory(self, directory_path):
        """Return how many direct entries of ``directory_path`` are regular files."""
        return sum(
            1
            for item in os.listdir(directory_path)
            if os.path.isfile(os.path.join(directory_path, item))
        )

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return ``(image, label)``."""
        image = Image.open(os.path.join(self.root, self.images[idx])).convert('RGB')

        if self.backdoor_sufi:
            logging.debug("backdoor sufi: %s", self.backdoor_sufi)
            if idx in self.backdoor_indices:
                image = self.add_trigger(image, patch_size = self.options.patch_size, patch_type = self.options.patch_type, patch_location = self.options.patch_location, tigger_pth=self.options.tigger_pth)
            # NOTE(review): this branch returns the image without applying
            # self.transform and with a hard-coded label of 954. Kept as-is;
            # confirm that this is intended.
            label = 954
            return image, label

        if self.add_backdoor and self.generate_backdoor:
            image = self.add_trigger(image, patch_size = self.options.patch_size, patch_type = self.options.patch_type, patch_location = self.options.patch_location, tigger_pth=self.options.tigger_pth, args=self.options)

            if self.options.save_files_name is not None:
                # Persist the triggered image next to the originals, under the
                # directory named by save_files_name.
                ori_file = os.path.dirname(os.path.join(self.root, self.images[idx]))
                save_file = ori_file.replace('ILSVRC2012_val', self.options.save_files_name)
                file_name = self.images[idx].split('/')[-1]
                os.makedirs(save_file, exist_ok=True)
                image.save(os.path.join(save_file, file_name))
        image = self.transform(image)
        label = self.labels[idx]
        return image, label

def get_eval_test_dataloader(options, processor):
    """Build the evaluation (test split) DataLoader for ``options.eval_data_type``.

    Returns ``None`` when ``options.eval_test_data_dir`` is ``None``.
    CSV-backed datasets are read from ``eval_test_data_dir`` through
    ``ImageLabelDataset``; torchvision datasets are rooted at its parent
    directory and requested with ``download = True``. Only the ImageNet1K
    dataset receives ``options``. The loader is annotated with ``num_samples``
    (``len(dataset)``) and ``num_batches``.

    Raises:
        Exception: if ``options.eval_data_type`` is not a supported name.
    """
    if options.eval_test_data_dir is None:
        return None

    kind = options.eval_data_type

    # Dataset names grouped by how their constructor is invoked.
    csv_backed = ("Caltech101", "Flowers102", "ImageNetSketch", "ImageNetV2", "ImageNet-A", "ImageNet-R")
    train_flag_based = ("CIFAR10", "CIFAR100")
    split_based = ("DTD", "FGVCAircraft", "Food101", "GTSRB", "OxfordIIITPet", "RenderedSST2", "StanfordCars", "STL10", "SVHN")

    if kind == "ImageNet1K":
        print(f'Test: {options.eval_test_data_dir}')
        dataset = ImageLabelDataset(
            root=options.eval_test_data_dir,
            transform=processor.process_image,
            options=options,
        )
    elif kind in csv_backed:
        dataset = ImageLabelDataset(
            root=options.eval_test_data_dir,
            transform=processor.process_image,
        )
    elif kind in train_flag_based:
        # The torchvision class name is identical to the eval_data_type string.
        dataset_cls = getattr(torchvision.datasets, kind)
        dataset = dataset_cls(
            root=os.path.dirname(options.eval_test_data_dir),
            download=True,
            train=False,
            transform=processor.process_image,
        )
    elif kind in split_based:
        dataset_cls = getattr(torchvision.datasets, kind)
        dataset = dataset_cls(
            root=os.path.dirname(options.eval_test_data_dir),
            download=True,
            split="test",
            transform=processor.process_image,
        )
    else:
        raise Exception(f"Eval test dataset type {options.eval_data_type} is not supported")

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=options.batch_size,
        num_workers=options.num_workers,
        sampler=None,
    )
    loader.num_samples = len(dataset)
    loader.num_batches = len(loader)
    return loader

def get_eval_test_benign_dataloader(options, processor):
    """Build the evaluation (test split) DataLoader with backdoor settings disabled.

    While the dataset is constructed, ``options.add_backdoor`` and
    ``options.asr`` are forced to ``False`` and ``options.patch_location``,
    ``options.patch_size`` and ``options.patch_type`` to ``None``. The caller's
    original values are always put back before this function returns or
    raises.

    Returns ``None`` (leaving ``options`` untouched) when
    ``options.eval_test_data_dir`` is ``None``. The loader is annotated with
    ``num_samples`` (``len(dataset)``) and ``num_batches``.

    Raises:
        Exception: if ``options.eval_data_type`` is not a supported name.
    """
    # Bail out before touching ``options`` so the early return cannot leave the
    # caller's backdoor settings clobbered.
    if options.eval_test_data_dir is None:
        return None

    benign_overrides = {
        "add_backdoor": False,
        "asr": False,
        "patch_location": None,
        "patch_size": None,
        "patch_type": None,
    }
    saved = {name: getattr(options, name) for name in benign_overrides}
    for name, value in benign_overrides.items():
        setattr(options, name, value)

    # Dataset names grouped by how their constructor is invoked.
    csv_backed = ("Caltech101", "Flowers102", "ImageNetSketch", "ImageNetV2", "ImageNet-A", "ImageNet-R")
    train_flag_based = ("CIFAR10", "CIFAR100")
    split_based = ("DTD", "FGVCAircraft", "Food101", "GTSRB", "OxfordIIITPet", "RenderedSST2", "StanfordCars", "STL10", "SVHN")

    try:
        kind = options.eval_data_type
        if kind == "ImageNet1K":
            print(f'Test: {options.eval_test_data_dir}')
            dataset = ImageLabelDataset(root = options.eval_test_data_dir, transform = processor.process_image, options = options)
        elif kind in csv_backed:
            dataset = ImageLabelDataset(root = options.eval_test_data_dir, transform = processor.process_image)
        elif kind in train_flag_based:
            # The torchvision class name is identical to the eval_data_type string.
            dataset = getattr(torchvision.datasets, kind)(root = os.path.dirname(options.eval_test_data_dir), download = True, train = False, transform = processor.process_image)
        elif kind in split_based:
            dataset = getattr(torchvision.datasets, kind)(root = os.path.dirname(options.eval_test_data_dir), download = True, split = "test", transform = processor.process_image)
        else:
            raise Exception(f"Eval test dataset type {options.eval_data_type} is not supported")

        dataloader = torch.utils.data.DataLoader(dataset, batch_size = options.batch_size, num_workers = options.num_workers, sampler = None)
        dataloader.num_samples = len(dataset)
        dataloader.num_batches = len(dataloader)
        return dataloader
    finally:
        # Restore the caller's settings on success and on failure alike.
        for name, value in saved.items():
            setattr(options, name, value)

def get_eval_train_dataloader(options, processor):
    """Build the shuffled evaluation (train split) DataLoader for ``options.eval_data_type``.

    Returns ``None`` when ``options.eval_train_data_dir`` is ``None``.
    CSV-backed datasets are read from ``eval_train_data_dir`` through
    ``ImageLabelDataset``; torchvision datasets are rooted at its parent
    directory and requested with ``download = True``. DTD concatenates its
    ``train`` and ``val`` splits. Batches use
    ``options.linear_probe_batch_size``. The loader is annotated with
    ``num_samples`` (``len(dataset)``) and ``num_batches``.

    Side effect: for ImageNet1K, ``options.add_backdoor`` is set to ``False``
    and is not restored afterwards.

    Raises:
        Exception: if ``options.eval_data_type`` is not a supported name.
    """
    if options.eval_train_data_dir is None:
        return None

    kind = options.eval_data_type

    csv_backed = ("Caltech101", "Flowers102")
    train_flag_based = ("CIFAR10", "CIFAR100")
    # torchvision datasets selected through ``split``, with the split to request.
    train_split_of = {
        "FGVCAircraft": "trainval",
        "Food101": "train",
        "GTSRB": "train",
        "OxfordIIITPet": "trainval",
        "RenderedSST2": "train",
        "StanfordCars": "train",
        "STL10": "train",
        "SVHN": "train",
    }

    if kind == "ImageNet1K":
        # NOTE(review): this permanently clears add_backdoor on the shared
        # options object. Kept because later code may rely on it; confirm.
        options.add_backdoor = False
        dataset = ImageLabelDataset(root = options.eval_train_data_dir, transform = processor.process_image, options = options)
    elif kind in csv_backed:
        dataset = ImageLabelDataset(root = options.eval_train_data_dir, transform = processor.process_image)
    elif kind in train_flag_based:
        # Always rooted at the *train* directory's parent. CIFAR100 previously
        # used eval_test_data_dir here, which fails when no test dir is set.
        dataset = getattr(torchvision.datasets, kind)(root = os.path.dirname(options.eval_train_data_dir), download = True, train = True, transform = processor.process_image)
    elif kind == "DTD":
        dtd_root = os.path.dirname(options.eval_train_data_dir)
        dataset = torch.utils.data.ConcatDataset([
            torchvision.datasets.DTD(root = dtd_root, download = True, split = "train", transform = processor.process_image),
            torchvision.datasets.DTD(root = dtd_root, download = True, split = "val", transform = processor.process_image),
        ])
    elif kind in train_split_of:
        dataset = getattr(torchvision.datasets, kind)(root = os.path.dirname(options.eval_train_data_dir), download = True, split = train_split_of[kind], transform = processor.process_image)
    else:
        raise Exception(f"Eval train dataset type {options.eval_data_type} is not supported")

    dataloader = torch.utils.data.DataLoader(dataset, batch_size = options.linear_probe_batch_size, num_workers = options.num_workers, sampler = None, shuffle = True)
    dataloader.num_samples = len(dataset)
    dataloader.num_batches = len(dataloader)

    return dataloader

def load(options, processor):
    """Build every dataloader used by a run and return them in a dict.

    Keys: ``train``, ``TAC_train`` (``None`` unless ``options.TAC_train`` is
    truthy), ``validation``, ``eval_test``, ``eval_test_benign``,
    ``eval_train`` and ``patch_train``.
    """
    if options.TAC_train:
        # The TAC loader is built before the plain training loader.
        loaders = {
            "TAC_train": get_TAC_train_dataloaer(options, processor),
            "train": get_train_dataloader(options, processor),
        }
    else:
        loaders = {
            "train": get_train_dataloader(options, processor),
            "TAC_train": None,
        }

    # Remaining loaders, built in this fixed order.
    remaining_builders = (
        ("validation", get_validation_dataloader),
        ("eval_test", get_eval_test_dataloader),
        ("eval_test_benign", get_eval_test_benign_dataloader),
        ("eval_train", get_eval_train_dataloader),
        ("patch_train", get_patch_train_dataloader),
    )
    for key, build in remaining_builders:
        loaders[key] = build(options, processor)

    return loaders